import numpy as np
import scipy.sparse as sp
import torch

# 把原本只有一个数字的label转换成独热编码
def encode_onehot(labels):
    '''
    :param labels: N*1
    :return: labels_onehot N*len(classes) 每一行只有他的标签处才为0
    '''
    classes = set(labels)
    classes_dict = {
        c: np.identity(len(classes))[i, :]  for i, c in enumerate(classes)
    }
    # 这里很聪明，因为他直接把E的对应的那一行取出来，就是one-hot 不过我觉得没有必要每一次都生成一个E，所以实际编码上有点多余

    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot


def load_data(path="../data/cora/", dataset="cora"):
    '''
    :param path:
    :param dataset:
    :return: adj # N*N
    :return features N*1433
    :return labels N*1 [0,7)
    :return idx_train
    :return idx_val
    :return idx_test
    '''
    """Load citation network dataset (cora only for now)"""

    print('Loading {} dataset...'.format(dataset))

    idx_features_labels = np.genfromtxt("{}{}.content".format(path, dataset),
                                        dtype=np.dtype(str))

    # Nx1435
    features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
    # Nx 1433

    # <paper_id> <word_attributes> + <class_label> 所以是从1到-1，这里的features 其实是词汇是否出现的意思

    labels = encode_onehot(idx_features_labels[:, -1]) # 把文章的标签拿出来 N*1

    # 经过独热编码处理之后的labels为 N*len(classes),也就是Nx7

    # build graph 接下来开始构建图
    idx = np.array(idx_features_labels[:, 0], dtype=np.int32) # 先把paper-id取出来
    idx_map = {j: i for i, j in enumerate(idx)}

    # paper-id 与 array中的index做了个对应

    edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset),
                                    dtype=np.int32)

    # 上面那个读取了无序的边

    edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                     dtype=np.int32).reshape(edges_unordered.shape)

    # 利用之前的dict做了个index的转换，从paper-id转换成这里的index,先展成一维，方便后续处理，再reshape回来就好了

    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(labels.shape[0], labels.shape[0]),
                        dtype=np.float32)

    # 上面这一步是真的骚，实际就是快速构建一个系数矩阵，第一个参数是数据数组，第二个参数是(row,col)数组

    # build symmetric adjacency matrix
    # 构建一个对称邻接矩阵，你这里本来是单向的，因为是引用和被引用的关系，但是GCN建模的是无向图，所以就对称下

    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) # 转置相加然后在减去中间的

    features = normalize(features)  # 特征归一化

    adj = normalize(adj + sp.eye(adj.shape[0])) # 邻接矩阵归一化 eye返回的是一个单位阵，默认是方阵，加上了自连接

    idx_train = range(140)
    idx_val = range(200, 500)
    idx_test = range(500, 1500)

    features = torch.FloatTensor(np.array(features.todense()))
    labels = torch.LongTensor(np.where(labels)[1]) # 真的不是很懂上面处理成one-hot要干嘛，现在我明白了，因为本来是个字符串你敢信
    adj = sparse_mx_to_torch_sparse_tensor(adj)

    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

    return adj, features, labels, idx_train, idx_val, idx_test

# 貌似是行归一化，就是每一行的每个值除以每一行的和而已
def normalize(mx):
    '''
    :param mx: row x col
    :return:
    '''
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1)) # row 每一行的和
    r_inv = np.power(rowsum, -1).flatten() # 每一行的合的倒数
    r_inv[np.isinf(r_inv)] = 0. # 如果这一行的导数是无穷，那就置为0
    r_mat_inv = sp.diags(r_inv) # 这个函数可真是够复杂的，自己查阅吧，大概意思是将一个向量变成一个对角阵
    mx = r_mat_inv.dot(mx)
    return mx


def accuracy(output, labels):
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)
